#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------

import argparse
import onnxruntime as onnxrt
import numpy as np
import os
import sys
from timeit import default_timer as timer

# ONNX floating-point tensor type strings (as reported by NodeArg.type)
# mapped to the numpy dtype name used to generate random data for them.
float_dict = {
    'tensor(float16)': 'float16',
    'tensor(float)': 'float32',
    'tensor(double)': 'float64',
}

# ONNX integer tensor type strings mapped to the numpy dtype name used to
# generate random data for them.
integer_dict = {
    'tensor(int32)': 'int32',
    'tensor(int8)': 'int8',
    'tensor(uint8)': 'uint8',
    'tensor(int16)': 'int16',
    'tensor(uint16)': 'uint16',
    'tensor(uint32)': 'uint32',
    'tensor(int64)': 'int64',
    'tensor(uint64)': 'uint64',
}

# Simple test program: load an ONNX model, feed random data to all of its inputs,
# and run the model num_iters times.


def _concrete_shape(dims):
    """Return *dims* as a list usable as a numpy shape.

    ONNX Runtime reports a named symbolic dimension as a string (e.g. 'batch')
    and an unnamed one as None. Those, and any non-positive value, are
    replaced with 1 so that a concrete array can be allocated.
    """
    return [1 if dim is None or isinstance(dim, str) or dim <= 0 else dim
            for dim in dims]


def _random_tensor(type_str, shape):
    """Return a random ndarray for ONNX tensor type *type_str* and *shape*.

    Returns None when *type_str* is not a supported element type.
    """
    size = tuple(shape)
    if type_str in float_dict:
        # np.asarray keeps scalar (shape ()) feeds as 0-d arrays; the former
        # np.random.rand() returned a plain float with no .astype method.
        return np.asarray(np.random.random_sample(size),
                          dtype=float_dict[type_str])
    if type_str in integer_dict:
        dtype = np.dtype(integer_dict[type_str])
        # Values are drawn from [0, 1000), but never beyond the dtype's
        # maximum, so narrow types like int8/uint8 don't wrap on conversion.
        high = min(1000, int(np.iinfo(dtype).max) + 1)
        return np.asarray(np.random.randint(0, high, size=size, dtype=dtype))
    if type_str == 'tensor(bool)':
        return np.asarray(np.random.randint(2, size=size), dtype=bool)
    return None


def _fill_feeds(feeds, node_args, kind):
    """Store a random tensor in *feeds* for each entry of *node_args*.

    Each entry must expose ``name``, ``type`` and ``shape`` (as the objects
    returned by ``sess.get_inputs()`` do). *kind* ('input' or 'initializer')
    is only used in the error message. An unsupported type prints an error
    and exits the process with status -1.
    """
    for arg in node_args:
        value = _random_tensor(arg.type, _concrete_shape(arg.shape))
        if value is None:
            print("unsupported {0} type {1} for {0} {2}".format(
                kind, arg.type, arg.name))
            sys.exit(-1)
        feeds[arg.name] = value


def main():
    """Load a model, feed it random data and report its average run latency.

    Command line: ``model_path [num_iters] [--debug] [--profile]``.

    Returns:
        0 on success. The process exits with -1 if the model has an input or
        overridable initializer of an unsupported type, and with an argparse
        usage error if num_iters is not a positive integer.
    """
    parser = argparse.ArgumentParser(description='Simple ONNX Runtime Test Tool.')
    parser.add_argument('model_path', help='model path')
    parser.add_argument('num_iters', nargs='?', type=int,
                        default=1000, help='model run iterations. default=1000')
    parser.add_argument('--debug', action='store_true',
                        help='pause execution to allow attaching a debugger.')
    parser.add_argument('--profile', action='store_true',
                        help='enable chrome timeline trace profiling.')
    args = parser.parse_args()
    iters = args.num_iters
    # The average latency below divides by iters, so reject 0 / negatives.
    if iters <= 0:
        parser.error('num_iters must be a positive integer')

    if args.debug:
        print("Pausing execution ready for debugger to attach to pid: {}".format(
            os.getpid()))
        print("Press key to continue.")
        sys.stdin.read(1)

    sess_options = None
    if args.profile:
        sess_options = onnxrt.SessionOptions()
        sess_options.enable_profiling = True
        sess_options.profile_file_prefix = os.path.basename(args.model_path)

    sess = onnxrt.InferenceSession(args.model_path, sess_options)
    meta = sess.get_modelmeta()

    feeds = {}
    _fill_feeds(feeds, sess.get_inputs(), 'input')
    # Starting with IR4 some initializers provide default values and can be
    # overridden by feeding them. For IR < 4 models the list is empty.
    _fill_feeds(feeds, sess.get_overridable_initializers(), 'initializer')

    start = timer()
    for _ in range(iters):
        sess.run([], feeds)  # fetch all outputs
    end = timer()

    print("model: {}".format(meta.graph_name))
    print("version: {}".format(meta.version))
    print("iterations: {}".format(iters))
    print("avg latency: {} ms".format(((end - start)*1000)/iters))

    if args.profile:
        trace_file = sess.end_profiling()
        print("trace file written to: {}".format(trace_file))

    return 0


# Run the tool only when executed as a script; main()'s return value becomes
# the process exit status (raising SystemExit is what sys.exit does).
if __name__ == "__main__":
    raise SystemExit(main())
